#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt 
import numpy as np
import time

def rv_gen(theta):
    """Draw a single +1/-1 outcome.

    One sample is taken from ``np.random.uniform()``; the result is ``1``
    when that sample is strictly below ``theta`` and ``-1`` otherwise.
    """
    draw = np.random.uniform()
    return 1 if draw < theta else -1

def indicator(x):
    """Return 1 for positive ``x``, 0 for negative ``x`` and 0.5 at exactly zero."""
    if x > 0:
        return 1
    if x == 0:
        # Tie value used at the kink.
        return 0.5
    return 0

def next_state_space(S_t, A_t):
    """Return the sorted list of states reachable from ``S_t`` in one step.

    The result starts as a copy of ``S_t``.  For every state ``s`` and every
    action ``a`` with ``s >= a``, both ``s - a`` and ``s + a`` are added
    unless already present.
    """
    reachable = list(S_t)
    for state in S_t:
        for action in A_t:
            if state < action:
                # Action not allowed in this state.
                continue
            for candidate in (state - action, state + action):
                if candidate not in reachable:
                    reachable.append(candidate)
    reachable.sort()
    return reachable

def bpmf(xi,theta):
    """Two-point probability mass: ``theta`` when ``xi > 0``, otherwise ``1 - theta``."""
    return theta if xi > 0 else 1 - theta

def cost_function(index_a, xi, c=10):
    """Stage cost ``c - a * xi`` for the action stored at ``A[index_a]``.

    ``A`` is the module-level action list; ``index_a`` is a position in it.
    """
    action = A[index_a]
    return c - action * xi

def next_state_index(current_t, current_s, index_a, xi):
    """Position of the successor state inside ``S_space[current_t + 1]``.

    The successor is ``current_s + A[index_a] * xi``.  ``list.index`` raises
    ``ValueError`` if that state is not in the next stage's state list.
    """
    successor = current_s + A[index_a] * xi
    return S_space[current_t + 1].index(successor)


def find_alpha_function(u_vector):
    """Backward recursion for the alpha-function tables at a fixed ``u_vector``.

    Reads the module globals ``N``, ``Theta``, ``A``, ``S_space``, ``alpha``
    and ``mu_0``.

    For every stage ``t`` an array of shape
    ``(len(Theta), len(S_space[t]), len(A))`` is filled.  Entries of
    infeasible actions (``s < a``) stay at ``inf``.  Feasible entries are
    ``u_t + max(x - u_t, 0) / (1 - alpha)``, where ``x`` is the theta-expected
    stage cost plus, for ``t < N - 1``, the theta-expected next-stage entry
    under the next action chosen by argmin.
    NOTE(review): this has the shape of a CVaR representation with auxiliary
    variable ``u_t`` -- confirm against the underlying paper.

    Parameters
    ----------
    u_vector : sequence indexed by stage ``0 .. N-1``.

    Returns
    -------
    (alpha_function, index_opt_a, V_0)
        ``alpha_function[t]`` is the stage-``t`` array.
        ``index_opt_a[t + 1][index_theta, index_s, index_a]`` is the argmin
        next-action index chosen from stage-``t`` state/action
        (100 is the filler for infeasible entries); ``index_opt_a[0]`` is
        the argmin over actions of the ``mu_0``-weighted stage-0 values.
        ``V_0`` is that minimum.
    """
    n_theta, n_action = len(Theta), len(A)
    alpha_function = {}
    index_opt_a = {}

    # Terminal stage: no continuation term.
    t = N - 1
    alpha_function_t = np.inf * np.ones([n_theta, len(S_space[t]), n_action])
    for index_theta, theta in enumerate(Theta):
        for index_s, s in enumerate(S_space[t]):
            for index_a, a in enumerate(A):
                if s >= a:
                    expected_cost = theta * cost_function(index_a, 1) + (1 - theta) * cost_function(index_a, -1)
                    alpha_function_t[index_theta, index_s, index_a] = 1 / (1 - alpha) * max(expected_cost - u_vector[t], 0) + u_vector[t]
    alpha_function[t] = alpha_function_t

    for t in range(N-2,-1,-1):
        index_opt_a_next = 100 * np.ones([n_theta, len(S_space[t]), n_action], dtype=np.int8)
        alpha_function_t = np.inf * np.ones([n_theta, len(S_space[t]), n_action])

        for index_theta, theta in enumerate(Theta):
            for index_s, s in enumerate(S_space[t]):
                for index_a, a in enumerate(A):
                    if s >= a:
                        next_up = alpha_function[t+1][index_theta, next_state_index(t, s, index_a, 1), :]
                        next_down = alpha_function[t+1][index_theta, next_state_index(t, s, index_a, -1), :]
                        # Expected next-stage value per next action.  The
                        # theta weight belongs to the xi = +1 successor; the
                        # same vector is used for the argmin and the value so
                        # the two cannot disagree.
                        continuation = theta * next_up + (1 - theta) * next_down
                        a_next_index = np.argmin(continuation)
                        index_opt_a_next[index_theta, index_s, index_a] = a_next_index
                        expected_cost = theta * cost_function(index_a, 1) + (1 - theta) * cost_function(index_a, -1)
                        alpha_function_t[index_theta, index_s, index_a] = 1 / (1 - alpha) * max(expected_cost - u_vector[t] + continuation[a_next_index], 0) + u_vector[t]
        alpha_function[t] = alpha_function_t
        index_opt_a[t+1] = index_opt_a_next

    # Stage 0 has a single state (index 0); average over theta with mu_0.
    Q_0 = [sum([mu_0[theta] * alpha_function[0][index_theta, 0, index_a] for index_theta, theta in enumerate(Theta)]) for index_a, a in enumerate(A)]
    index_opt_a[0] = np.argmin(Q_0)
    V_0 = min(Q_0)
    return (alpha_function, index_opt_a, V_0)

def find_u_gradient(u_vector, alpha_function, index_opt_a):
    """Sampled gradient of the stage-0 value with respect to ``u_vector``.

    For each ``theta`` in ``Theta`` one trajectory of length ``N`` is
    simulated with ``rv_gen(theta)``, starting at state index 0 with action
    ``index_opt_a[0]`` and following ``index_opt_a[t + 1]`` afterwards.
    Along the trajectory the per-stage derivative
    ``1 - indicator(.) / (1 - alpha)`` is multiplied by the running product
    of ``(1 - derivative)`` from the earlier stages.  The per-theta terms are
    then averaged with the weights ``mu_0``.

    Returns a list of length ``N``.
    """
    terms_by_theta = []
    for index_theta, theta in enumerate(Theta):
        index_s = 0
        index_a = index_opt_a[0]
        chain = 1  # product of (1 - derivative) over the earlier stages
        stage_terms = []
        for t in range(N):
            s = S_space[t][index_s]
            xi = rv_gen(theta)
            index_s_next = next_state_index(t, s, index_a, xi)
            if t == N - 1:
                # Last stage: no continuation value inside the indicator.
                d_u = 1 - 1 / (1 - alpha) * indicator(cost_function(index_a, xi) - u_vector[t])
                stage_terms.append(chain * d_u)
            else:
                index_a_next = index_opt_a[t+1][index_theta, index_s, index_a]
                d_u = 1 - 1 / (1 - alpha) * indicator(cost_function(index_a, xi) - u_vector[t] + alpha_function[t+1][index_theta, index_s_next, index_a_next])
                stage_terms.append(chain * d_u)
                chain = chain * (1 - d_u)
                # Move on to the sampled successor state and its action.
                index_s, index_a = index_s_next, index_a_next
        terms_by_theta.append(stage_terms)

    u_gradient = [sum([mu_0[theta] * terms_by_theta[index_theta][t] for index_theta, theta in enumerate(Theta)]) for t in range(N)]
    return u_gradient

def SGD(u_vector, K = 1000, parm1 = 1, parm2 = 1000, SGD_iter = 20):
    """Search over ``u_vector`` with sampled-gradient steps.

    Each of the ``K`` outer iterations recomputes the alpha-function tables
    with ``find_alpha_function``, records the best ``V_0`` seen so far, and
    then takes ``SGD_iter`` gradient steps of size
    ``eta = parm1 / (parm2 + i)`` using ``find_u_gradient``.

    Parameters
    ----------
    u_vector : object with ``.copy()`` that supports ``u - eta * array``
        (the caller in this file passes a NumPy array of length ``N``).
    K : number of outer iterations.
    parm1, parm2 : step-size numerator and denominator offset.
    SGD_iter : gradient steps per outer iteration.

    Returns
    -------
    (u_opt, V_opt, alpha_function_opt)
        The ``u_vector`` with the smallest ``V_0``, that value, and the
        matching alpha-function dict.  If no iteration improves on ``inf``
        (for example ``K == 0``) the initial ``u_vector``, ``inf`` and an
        empty dict are returned instead of raising ``UnboundLocalError``.
    """
    V_opt = np.inf
    # Defined up front so the return statement is always valid.
    u_opt = u_vector.copy()
    alpha_function_opt = {}
    for i in range(K):
        eta = parm1 / (parm2 + i)
        alpha_function, index_opt_a, V_0 = find_alpha_function(u_vector)
        if V_0 < V_opt:
            V_opt = V_0
            u_opt = u_vector.copy()
            alpha_function_opt = alpha_function.copy()
        # The tables are held fixed while the gradient steps are taken.
        for _ in range(SGD_iter):
            u_gradient = find_u_gradient(u_vector, alpha_function, index_opt_a)
            u_vector = u_vector - eta * np.array(u_gradient)
    return (u_opt, V_opt, alpha_function_opt)

def SGD_1(u_vector, K = 1000, parm1 = 1, parm2 = 1000, SGD_iter = 20):
    """Self-contained variant of ``SGD`` with both steps written inline.

    Reads the module globals ``N``, ``Theta``, ``A``, ``S_space``, ``alpha``
    and ``mu_0``.  Each of the ``K`` outer iterations

    1. rebuilds the alpha-function tables by backward recursion for the
       current ``u_vector`` (entries of infeasible actions stay at ``inf``),
    2. records the best stage-0 value ``V_0`` seen so far, and
    3. takes ``SGD_iter`` sampled-gradient steps of size
       ``eta = parm1 / (parm2 + i)`` with the tables held fixed.

    Parameters
    ----------
    u_vector : object with ``.copy()`` that supports ``u - eta * array``
        (the caller in this file passes a NumPy array of length ``N``).
    K : number of outer iterations.
    parm1, parm2 : step-size numerator and denominator offset.
    SGD_iter : gradient steps per outer iteration.

    Returns
    -------
    (u_opt, V_opt, alpha_function_opt)
        The ``u_vector`` with the smallest ``V_0``, that value, and the
        matching alpha-function dict.  If no iteration improves on ``inf``
        (for example ``K == 0``) the initial ``u_vector``, ``inf`` and an
        empty dict are returned instead of raising ``UnboundLocalError``.
    """
    V_opt = np.inf
    # Defined up front so the return statement is always valid.
    u_opt = u_vector.copy()
    alpha_function_opt = {}
    for i in range(K):
        eta = parm1 / (parm2 + i)
        alpha_function = {}
        index_opt_a = {}

        # Terminal stage: no continuation term.
        t = N - 1
        alpha_function_t = np.inf * np.ones([len(Theta), len(S_space[t]), len(A)])
        for index_theta, theta in enumerate(Theta):
            for index_s, s in enumerate(S_space[t]):
                for index_a, a in enumerate(A):
                    if s >= a:
                        expected_cost = theta * cost_function(index_a, 1) + (1 - theta) * cost_function(index_a, -1)
                        alpha_function_t[index_theta, index_s, index_a] = 1 / (1 - alpha) * max(expected_cost - u_vector[t], 0) + u_vector[t]
        alpha_function[t] = alpha_function_t

        for t in range(N-2,-1,-1):
            index_opt_a_next = 100 * np.ones([len(Theta), len(S_space[t]), len(A)], dtype=np.int8)
            alpha_function_t = np.inf * np.ones([len(Theta), len(S_space[t]), len(A)])

            for index_theta, theta in enumerate(Theta):
                for index_s, s in enumerate(S_space[t]):
                    for index_a, a in enumerate(A):
                        if s >= a:
                            next_up = alpha_function[t+1][index_theta, next_state_index(t, s, index_a, 1), :]
                            next_down = alpha_function[t+1][index_theta, next_state_index(t, s, index_a, -1), :]
                            # The theta weight belongs to the xi = +1
                            # successor; the same vector drives both the
                            # argmin and the stored value.
                            continuation = theta * next_up + (1 - theta) * next_down
                            a_next_index = np.argmin(continuation)
                            index_opt_a_next[index_theta, index_s, index_a] = a_next_index
                            expected_cost = theta * cost_function(index_a, 1) + (1 - theta) * cost_function(index_a, -1)
                            alpha_function_t[index_theta, index_s, index_a] = 1 / (1 - alpha) * max(expected_cost - u_vector[t] + continuation[a_next_index], 0) + u_vector[t]
            alpha_function[t] = alpha_function_t
            index_opt_a[t+1] = index_opt_a_next

        # Stage 0 has a single state (index 0); average over theta with mu_0.
        Q_0 = [sum([mu_0[theta] * alpha_function[0][index_theta, 0, index_a] for index_theta, theta in enumerate(Theta)]) for index_a, a in enumerate(A)]
        index_opt_a[0] = np.argmin(Q_0)
        V_0 = min(Q_0)
        if V_0 < V_opt:
            V_opt = V_0
            u_opt = u_vector.copy()
            alpha_function_opt = alpha_function.copy()

        # Update u_vector with sampled gradients; one trajectory per theta.
        for _ in range(SGD_iter):
            partial_alpha_0_alpha = {}
            partial_alpha_0_u = {}
            for index_theta, theta in enumerate(Theta):
                index_s = 0
                index_a = index_opt_a[0]
                partial_alpha_0_alpha[(0,index_theta)] = 1
                for t in range(N):
                    s = S_space[t][index_s]
                    xi = rv_gen(theta)
                    index_s_next = next_state_index(t, s, index_a, xi)
                    if t < N-1:
                        index_a_next = index_opt_a[t+1][index_theta, index_s, index_a]
                        partial_alpha_t_u_t = 1 - 1 / (1 - alpha) * indicator(cost_function(index_a, xi) - u_vector[t] + alpha_function[t+1][index_theta, index_s_next, index_a_next])
                        partial_alpha_0_alpha[(t+1,index_theta)] = partial_alpha_0_alpha[(t,index_theta)] * (1 - partial_alpha_t_u_t)
                        # Pass the state and action to the next stage.
                        index_s = index_s_next
                        index_a = index_a_next
                    else:
                        partial_alpha_t_u_t = 1 - 1 / (1 - alpha) * indicator(cost_function(index_a, xi) - u_vector[t])
                    partial_alpha_0_u[(t,index_theta)] = partial_alpha_0_alpha[(t,index_theta)] * partial_alpha_t_u_t

            u_gradient = [sum([ mu_0[theta] * partial_alpha_0_u[(t,index_theta)] for index_theta, theta in enumerate(Theta)]) for t in range(N)]
            u_vector = u_vector - eta * np.array(u_gradient)
    return (u_opt, V_opt, alpha_function_opt)
# Algorithm: alpha function approximation for a given u vector (of size N)

# alpha_policy evaluated in the true environment
def g_1(s,a,xi):
    """State transition: the state ``s`` shifted by ``a * xi``."""
    shift = a * xi
    return s + shift


def cost_func(a, xi, c=10):
    """Stage cost ``c - a * xi`` for an action value ``a`` (not an index)."""
    payoff = a * xi
    return c - payoff

def mu_space(n,h,weight,weight_round =4):
    """Enumerate grid vectors of length ``n`` whose entries use up ``weight``.

    The first ``n - 1`` entries step through ``np.arange(0, weight + h, h)``;
    the last entry is the remaining weight rounded to ``weight_round``
    decimals.  Branches whose rounded remaining weight is negative are
    dropped.

    Parameters
    ----------
    n : number of components in each vector.
    h : grid step for the first ``n - 1`` components.
    weight : total weight still to be distributed.
    weight_round : decimals used when rounding the remaining weight.  It is
        now forwarded to the recursive calls so a non-default value applies
        at every level, not only the outermost one.

    Returns
    -------
    list of lists, each of length ``n``; empty if ``weight`` rounds below 0.
    """
    remaining = round(weight, weight_round)
    if remaining < 0:
        return []
    if n == 1:
        return [[remaining]]
    grid = []
    for i in np.arange(0, weight + h, h):
        grid += [[i] + tail for tail in mu_space(n - 1, h, weight - i, weight_round)]
    return grid
    
def g_2(mu, xi):
    """Bayes-update ``mu`` on one observation ``xi`` and snap it to the grid.

    The weight of each ``theta`` in ``Theta`` is multiplied by
    ``bpmf(xi, theta)`` and the result is normalised.  The returned object is
    the element of the module-level list ``Mu`` with the smallest squared
    Euclidean distance to that posterior (the first one on ties).
    """
    posterior = {theta: mu[theta] * bpmf(xi, theta) for theta in Theta}
    total = sum(posterior.values())
    for theta in Theta:
        posterior[theta] = posterior[theta] / total

    # Projection onto the grid by squared L2 distance (linear scan).
    best_distance = np.inf
    for candidate in Mu:
        distance = sum([(posterior[theta] - candidate[theta])**2 for theta in Theta])
        if distance < best_distance:
            mu_proj, best_distance = candidate, distance
    return mu_proj

def alpha_policy_evaluation(alpha_function):
    """Evaluate the policy induced by ``alpha_function`` under ``theta_c``.

    At every (stage, state, grid belief) the action minimising the
    belief-weighted alpha-function entry is chosen among actions with
    ``s >= a``.  The value of following that choice is then computed by
    backward recursion with win probability ``theta_c`` (module global),
    moving the belief index through ``posterior_transition_matrix``.

    Returns the value at stage 0 for state ``s_0`` and belief ``mu_0``.
    """
    V_DP = {N: {}}
    pi = {}
    for s in S_space[N]:
        for index_mu in range(len(Mu)):
            V_DP[N][(s,index_mu)] = 0

    for t in reversed(range(N)):
        stage_values = {}
        stage_policy = {}
        V_DP[t] = stage_values
        pi[t] = stage_policy
        for s in S_space[t]:
            index_s = S_space[t].index(s)
            for index_mu, mu in enumerate(Mu):
                best_score = float('inf')
                for index_a, a in enumerate(A):
                    if s < a:
                        continue
                    score = sum([alpha_function[t][index_theta, index_s, index_a] * mu[theta] for index_theta, theta in enumerate(Theta)])
                    if score < best_score:
                        best_score = score
                        a_opt = a
                stage_policy[(s,index_mu)] = a_opt

                # Successor keys after a win (xi = 1) and a loss (xi = -1).
                win_key = (g_1(s,a_opt,1), posterior_transition_matrix[(index_mu, 1)])
                loss_key = (g_1(s,a_opt,-1), posterior_transition_matrix[(index_mu, -1)])
                stage_values[(s,index_mu)] = (cost_func(a_opt, 1) + V_DP[t+1][win_key]) * theta_c + (cost_func(a_opt, -1) + V_DP[t+1][loss_key]) * (1-theta_c)

    return V_DP[0][(s_0,Mu.index(mu_0))]


# MLE
# theta_mle = sum([1 for xi in data if xi>0])/len(data) 
def DP_mle(theta_mle):
    """Backward dynamic programme for a single win probability ``theta_mle``.

    Terminal values are 0.  At each stage and state the action with
    ``s >= a`` that minimises the expected stage cost plus next-stage value
    is selected.

    Returns ``(pi, V)`` where ``pi[t][s]`` is the chosen action and ``V`` is
    the stage-0 value at ``s_0``.
    """
    V_DP = {N: {s: 0 for s in S_space[N]}}
    pi = {}
    for t in reversed(range(N)):
        stage_values = {}
        stage_policy = {}
        V_DP[t] = stage_values
        pi[t] = stage_policy
        successor = V_DP[t+1]
        for s in S_space[t]:
            best = float('inf')
            for a in A:
                if s < a:
                    continue
                win_total = cost_func(a, 1) + successor[g_1(s,a,1)]
                loss_total = cost_func(a, -1) + successor[g_1(s,a,-1)]
                V = win_total * theta_mle + loss_total * (1-theta_mle)
                if V < best:
                    best = V
                    a_opt = a
            stage_values[s] = best
            stage_policy[s] = a_opt
    return (pi, V_DP[0][s_0])

def DP_mle_policy_evaluation(pi_mle, theta_c):
    """Expected total cost of the state-indexed policy ``pi_mle`` under ``theta_c``.

    ``pi_mle[t][s]`` gives the action at stage ``t`` in state ``s``.  Values
    are propagated backwards from 0 at stage ``N``; the result is the
    stage-0 value at ``s_0``.
    """
    V_DP = {N: {s: 0 for s in S_space[N]}}
    for t in reversed(range(N)):
        stage_values = {}
        V_DP[t] = stage_values
        successor = V_DP[t+1]
        for s in S_space[t]:
            a = pi_mle[t][s]
            win_total = cost_func(a, 1) + successor[g_1(s,a,1)]
            loss_total = cost_func(a, -1) + successor[g_1(s,a,-1)]
            stage_values[s] = win_total * theta_c + loss_total * (1-theta_c)

    return V_DP[0][s_0]

def DP_DRO_1():
    """Backward recursion taking, per state, the worst ``theta`` of the best action.

    For each stage and state, every ``theta`` in ``Theta`` gets its own
    cost-minimising feasible action; the state's value is the largest of
    those minima and its action is the one belonging to that ``theta``.

    Returns ``(pi, V)`` with ``pi[t][s]`` the stored action and ``V`` the
    stage-0 value at ``s_0``.
    """
    V_DP = {N: {s: 0 for s in S_space[N]}}
    pi = {}
    for t in reversed(range(N)):
        V_DP[t] = {}
        pi[t] = {}
        successor = V_DP[t+1]
        for s in S_space[t]:
            worst_value = -float('inf')
            for theta in Theta:
                best_for_theta = float('inf')
                for a in A:
                    if s < a:
                        continue
                    win_total = cost_func(a, 1) + successor[g_1(s,a,1)]
                    loss_total = cost_func(a, -1) + successor[g_1(s,a,-1)]
                    V = win_total * theta + loss_total * (1-theta)
                    if V < best_for_theta:
                        best_for_theta = V
                        a_opt_theta = a
                if best_for_theta > worst_value:
                    worst_value = best_for_theta
                    a_opt = a_opt_theta
            V_DP[t][s] = worst_value
            pi[t][s] = a_opt
    return (pi, V_DP[0][s_0])
# pi_mle = DP_mle(theta_mle)
# V_0_mle = DP_mle_policy_evaluation(pi_mle, theta_c)

def rho_metric(V_distribution, metric, q):
    """Summarise a discrete distribution of values by a risk functional.

    Parameters
    ----------
    V_distribution : dict mapping value -> probability weight.
    metric : ``'VaR'``, ``'CVaR'`` or ``'mean'``.
    q : level used by ``'VaR'`` and ``'CVaR'``; ignored for ``'mean'``.

    Returns
    -------
    * ``'mean'``: the probability-weighted sum of the values.
    * ``'VaR'``: the smallest value whose cumulative weight (values taken in
      ascending order) reaches ``q``.
    * ``'CVaR'``: the weighted average of all values from that point
      upwards, including the whole atom at which ``q`` is reached.

    If the cumulative weight never reaches ``q`` (for instance weights that
    sum to ``0.9999999999999999`` with ``q = 1.0``), the largest value is
    returned rather than ``None`` (VaR) or a ``ZeroDivisionError`` (CVaR).

    Raises
    ------
    ValueError
        For an unknown ``metric``, or an empty distribution with
        ``'VaR'`` / ``'CVaR'``.
    """
    if metric == 'mean':
        return sum([V * V_distribution[V] for V in V_distribution])
    if metric not in ('VaR', 'CVaR'):
        raise ValueError("unknown metric: %r" % (metric,))
    if not V_distribution:
        raise ValueError("V_distribution must not be empty for %s" % metric)

    ordered_values = sorted(V_distribution)
    cumulative = 0

    if metric == 'VaR':
        for V in ordered_values:
            cumulative += V_distribution[V]
            if cumulative >= q:
                return V
        # Rounding kept the total just below q: the quantile is the maximum.
        return ordered_values[-1]

    # CVaR: average over the tail starting where the cumulative weight hits q.
    numerator = 0
    denominator = 0
    for V in ordered_values:
        weight = V_distribution[V]
        cumulative += weight
        if cumulative >= q:
            numerator += V * weight
            denominator += weight
    if denominator == 0:
        # Empty tail (q never reached): fall back to the maximum.
        return ordered_values[-1]
    return numerator / denominator


def DP_BRMDP(q, metric = 'CVaR'):
    """Backward recursion on (state, grid belief) pairs with a risk functional.

    For every stage, state ``s`` and belief ``Mu[index_mu]``, each feasible
    action (``s >= a``) induces a distribution over ``theta``: the
    ``theta``-expected stage cost plus next-stage value, weighted by
    ``mu[theta]``.  That distribution is collapsed with
    ``rho_metric(..., metric, q)`` and the minimising action is stored.
    Belief indices move through ``posterior_transition_matrix``.

    Parameters
    ----------
    q : level passed to ``rho_metric``.
    metric : metric name passed to ``rho_metric``.

    Returns
    -------
    (pi, V)
        ``pi[t][(s, index_mu)]`` is the chosen action; ``V`` is the stage-0
        value at state ``s_0`` and belief ``mu_0``.
    """
    V_DP = {}
    pi = {}
    V_DP[N] = {}
    for s in S_space[N]:
        for index_mu, mu in enumerate(Mu):
            V_DP[N][(s,index_mu)] = 0
    for t in range(N-1,-1,-1):
        V_DP[t] = {}
        pi[t] = {}
        for s in S_space[t]:
            for index_mu, mu in enumerate(Mu):
                V_opt = float('inf')
                for a in A:
                    if s >= a:
                        V_distribution = {}
                        for theta in Theta:
                            value = 0
                            for xi in [-1,1]:
                                index_mu_next = posterior_transition_matrix[(index_mu, xi)]
                                s_next = g_1(s,a,xi)
                                value += (cost_func(a, xi) +V_DP[t+1][(s_next,index_mu_next)])  * bpmf(xi,theta)
                            # Thetas that give the same value share one atom.
                            # dict.get replaces the former bare ``except:``,
                            # which would also have hidden unrelated errors.
                            V_distribution[value] = V_distribution.get(value, 0) + mu[theta]
                        V = rho_metric(V_distribution, metric, q)
                        if V < V_opt:
                            V_opt = V
                            a_opt = a
                V_DP[t][(s,index_mu)] = V_opt
                pi[t][(s,index_mu)] = a_opt

    return (pi, V_DP[0][(s_0,Mu.index(mu_0))])

def DP_BRMDP_evaluation(pi, theta_c):
    """Expected total cost of a (state, belief)-indexed policy under ``theta_c``.

    ``pi[t][(s, index_mu)]`` gives the action.  Values start at 0 for stage
    ``N`` and are propagated backwards with outcome probabilities
    ``bpmf(xi, theta_c)``; the belief index moves through
    ``posterior_transition_matrix``.  Returns the stage-0 value at ``s_0``
    and ``mu_0``.
    """
    V_DP = {N: {}}
    for s in S_space[N]:
        for index_mu in range(len(Mu)):
            V_DP[N][(s,index_mu)] = 0

    for t in reversed(range(N)):
        stage_values = {}
        V_DP[t] = stage_values
        successor = V_DP[t+1]
        for s in S_space[t]:
            for index_mu in range(len(Mu)):
                a = pi[t][(s,index_mu)]
                total = 0
                for xi in (-1, 1):
                    next_key = (g_1(s,a,xi), posterior_transition_matrix[(index_mu, xi)])
                    total += (cost_func(a, xi) + successor[next_key]) * bpmf(xi,theta_c)
                stage_values[(s,index_mu)] = total

    return V_DP[0][(s_0,Mu.index(mu_0))]


def prior_update(mu_t, data):
    """Bayes-update ``mu_t`` on a batch of observations and snap it to the grid.

    A copy of ``mu_t`` has each ``theta`` weight multiplied by
    ``bpmf(xi, theta)`` for every ``xi`` in ``data``, then is normalised.
    The returned object is the element of the module-level list ``Mu`` with
    the smallest squared Euclidean distance to it (the first one on ties).
    ``mu_t`` itself is not modified.
    """
    posterior = dict(mu_t)
    for theta in Theta:
        for xi in data:
            posterior[theta] = posterior[theta] * bpmf(xi,theta)
    total = sum(posterior.values())
    for theta in Theta:
        posterior[theta] = posterior[theta] / total

    # Projection onto the grid by squared L2 distance (linear scan).
    best_distance = np.inf
    for candidate in Mu:
        distance = sum([(posterior[theta] - candidate[theta])**2 for theta in Theta])
        if distance < best_distance:
            mu_proj, best_distance = candidate, distance
    return mu_proj


def DP_DRO():
    """Pick the ``DP_mle`` policy of the ``theta`` with the largest optimal value.

    ``DP_mle`` is solved once for each ``theta`` in ``Theta``.  Returns a
    shallow copy of the policy dict with the highest stage-0 value, together
    with that value.
    """
    worst_value = -np.inf
    for theta in Theta:
        candidate_policy, candidate_value = DP_mle(theta)
        if candidate_value > worst_value:
            worst_value = candidate_value
            pi_opt = candidate_policy.copy()
    return (pi_opt, worst_value)


# Parameters and precomputed tables.  Everything defined here is a module
# global that the functions above read directly rather than receive as
# arguments.
N = 6 # time horizon
s_0 = 60  # initial state; the only element of S_space[0]
A = [0,1,2,5,10]  # action values; an action a is only used in states with s >= a
index_s_0 = 0  # not referenced anywhere else in this file
# S_space[t] is the sorted list of states reachable at stage t, for t = 0..N.
S_space = {}
S_space[0] = [s_0]
for t in range(1,N+1):
    S_space[t] = next_state_space(S_space[t-1], A)

# Candidate values of the probability that xi > 0 (see bpmf).
Theta = [0.1,0.3,0.45,0.55,0.7,0.9]
# Uniform weights over Theta; rebound later by the experiment loops.
mu_0 = {}
for theta in Theta:
    mu_0[theta] = 1/len(Theta)

# Grid of weight vectors over Theta with step 0.1 and total weight 1.
Mu_space = mu_space(len(Theta),0.1,1)
# Same grid with every vector turned into a dict keyed by theta.
Mu = []
for x in Mu_space:
    mu = {}
    for i in range(len(Theta)):
        theta = Theta[i]
        mu[theta] = x[i]
    Mu += [mu]

# posterior_transition_matrix[(index_mu, xi)] is the index in Mu of the grid
# point that g_2 returns after updating Mu[index_mu] on the observation xi.
posterior_transition_matrix = {}
for index_mu, mu in enumerate(Mu):
    for xi in [-1,1]:
        mu_next = g_2(mu, xi)
        posterior_transition_matrix[(index_mu, xi)] = Mu.index(mu_next)

# for xi in data:
#     mu_0 = g_2(mu_0, xi)


# Comparison experiment: for every (theta_c, alpha) pair, run 100 replications
# and record the mean and standard deviation of the value each method's
# policy attains when evaluated under theta_c.
value_approx_DP_stats = {}
value_mle_stats = {}
value_BRMDP_stats = {}
# Solve times in seconds, accumulated over all (theta_c, alpha) pairs.
times_approx_DP = []
times_BRMDP = []
times_MLE = []
# DP_DRO takes no arguments, so it is solved once before the loops.
pi_DRO, V_0_DRO_star = DP_DRO()
value_DRO_stats = {}
# theta_c and alpha are module globals: alpha_policy_evaluation reads theta_c
# and SGD_1 reads alpha, so the loop variables below also configure them.
for theta_c in [0.45,0.55]:
    # Reseeded once per theta_c; both alpha values then share one random stream.
    np.random.seed(1)
    for alpha in [0.4,0.6]:
        value_approx_DP = []
        value_mle = []
        value_BRMDP = []
        value_DRO = []
        for i in range(100):
            print('theta: ',theta_c, 'alpha: ', alpha, 'iteration: ',i)
            # Ten observations generated with theta_c.
            data = [rv_gen(theta_c) for i in range(10)]
            # Reset to uniform weights, then update on the data and project
            # onto the grid.  mu_0 is a global read by SGD_1, DP_BRMDP and
            # the evaluation routines.
            mu_0 = {}
            for theta in Theta:
                mu_0[theta] = 1/len(Theta)
            mu_0 = prior_update(mu_0, data)
            # Initial u_vector: 10*N, 10*(N-1), ..., 10.
            u_vector = np.array([max(10 * N - 10 * i,0) for i in range(N)])
            t_before = time.time()
            u_vector, V_0_approx_DP_star, alpha_function = SGD_1(u_vector, K = 100, parm1 = 100, parm2 = 1)
            t_after = time.time()
            times_approx_DP += [t_after - t_before]
            V_0_approx_DP = alpha_policy_evaluation(alpha_function)
            value_approx_DP += [V_0_approx_DP]
            # Fraction of positive outcomes in the data.
            theta_mle = sum([1 for xi in data if xi>0])/len(data) 
            t_before = time.time()
            pi_mle, V_0_mle_star = DP_mle(theta_mle)
            t_after = time.time()
            times_MLE += [t_after - t_before]
            V_0_mle = DP_mle_policy_evaluation(pi_mle, theta_c)
            value_mle += [V_0_mle]
            t_before = time.time()
            # alpha is passed as the level q of the default 'CVaR' metric.
            pi_BRMDP, V_0_BRMDP_star = DP_BRMDP(q = alpha)
            t_after = time.time()
            times_BRMDP += [t_after - t_before]
            V_0_BRMDP = DP_BRMDP_evaluation(pi_BRMDP, theta_c)
            value_BRMDP += [V_0_BRMDP]
            V_0_DRO = DP_mle_policy_evaluation(pi_DRO, theta_c)
            value_DRO += [V_0_DRO]
            print('approx DP V_0: ', V_0_approx_DP, 'MLE V_0: ', V_0_mle, 'Exact BRMDP V_0: ', V_0_BRMDP, 'theta_MLE: ', theta_mle)
            
        # print(np.mean(value_approx_DP), np.std(value_approx_DP))
        # print(np.mean(value_mle), np.std(value_mle))
        # (mean, std) over the 100 replications for this (theta_c, alpha) pair.
        value_approx_DP_stats[(theta_c, alpha)] = (np.mean(value_approx_DP), np.std(value_approx_DP))
        value_mle_stats[(theta_c, alpha)] = (np.mean(value_mle), np.std(value_mle))
        value_BRMDP_stats[(theta_c, alpha)] = (np.mean(value_BRMDP), np.std(value_BRMDP))
        value_DRO_stats[(theta_c, alpha)] = (np.mean(value_DRO), np.std(value_DRO))




# Second experiment: exact BRMDP at level q = 0.9999 versus DRO, evaluated
# under theta_c = 0.55.  Results are stored under the key (theta_c, 1).
theta_c = 0.55
value_BRMDP = []
value_DRO = []
times_BRMDP_1 = []
times_DRO = []

np.random.seed(1)
for i in range(100):
    # NOTE(review): alpha here is whatever the comparison loop left behind
    # (its last value); it is only printed, not used in this experiment.
    print(theta_c,alpha, i)
    data = [rv_gen(theta_c) for i in range(10)]
    # Reset to uniform weights, then update on the data and project onto the grid.
    mu_0 = {}
    for theta in Theta:
        mu_0[theta] = 1/len(Theta)
    mu_0 = prior_update(mu_0, data)
    t_before = time.time()
    pi_BRMDP, V_0_BRMDP_star = DP_BRMDP(q = 0.9999)
    t_after = time.time()
    times_BRMDP_1 += [t_after - t_before]
    V_0_BRMDP = DP_BRMDP_evaluation(pi_BRMDP, theta_c)
    value_BRMDP += [V_0_BRMDP]
    t_before = time.time()
    # DP_DRO takes no arguments; it is re-solved in every replication so
    # that its run time can be recorded.
    pi_DRO, V_0_DRO_star = DP_DRO()
    t_after = time.time()
    times_DRO += [t_after - t_before]
    V_0_DRO = DP_mle_policy_evaluation(pi_DRO, theta_c)
    value_DRO += [V_0_DRO]

value_BRMDP_stats[(theta_c, 1)] = (np.mean(value_BRMDP), np.std(value_BRMDP))
value_DRO_stats[(theta_c, 1)] = (np.mean(value_DRO), np.std(value_DRO))
# Mean solve time in seconds per method.
mean_time = {}
mean_time['DRO'] = np.mean(times_DRO)
mean_time['Exact BRMDP alpha = 1'] = np.mean(times_BRMDP_1)
mean_time['Exact BRMDP'] = np.mean(times_BRMDP)
mean_time['Approximate BRMDP'] = np.mean(times_approx_DP)
mean_time['MLE'] = np.mean(times_MLE)

    